{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segmenting from Prompts\n",
    "\n",
    "This notebook walks through predicting object segmentations from a provided prompt. It uses the `Predictor` class. It is modified from [original SAM GitHub repo](https://github.com/facebookresearch/segment-anything/).\n",
    "\n",
    "### Setup\n",
    "Necessary imports and helper functions for displaying points, boxes, and masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import mlx.core as mx\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mask(mask, ax, random_color=False):\n",
    "    if random_color:\n",
    "        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
    "    else:\n",
    "        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    h, w = mask.shape[:2]\n",
    "    mask_image = np.array(mask).reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
    "    ax.imshow(mask_image)\n",
    "    \n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    pos_points = np.array(coords)[labels==1]\n",
    "    neg_points = np.array(coords)[labels==0]\n",
    "    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   \n",
    "    \n",
    "def show_box(box, ax):\n",
    "    box = box.tolist()\n",
    "    x0, y0 = box[0], box[1]\n",
    "    w, h = box[2] - box[0], box[3] - box[1]\n",
    "    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = cv2.imread('images/truck.jpg')\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selecting objects with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything.sam import load\n",
    "from segment_anything.predictor import SamPredictor\n",
    "\n",
    "sam_checkpoint = \"../sam-vit-base\"\n",
    "sam = load(sam_checkpoint)\n",
    "predictor = SamPredictor(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.set_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = mx.array([[500, 375]])\n",
    "input_label = mx.array([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('on')\n",
    "plt.show()  "
   ]
  },
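  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the (x, y) convention concrete, the sketch below uses a small placeholder array rather than the truck image. It shows that a prompt point `(x, y)` addresses row `y` and column `x` of an HWC image array. The array size and the names `demo_image`, `demo_point`, and `demo_label` are assumptions made only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Placeholder 4x6 RGB image (H=4, W=6); not the truck image\n",
    "demo_image = np.zeros((4, 6, 3), dtype=np.uint8)\n",
    "demo_point = np.array([[5, 2]])  # one point in (x, y) order\n",
    "demo_label = np.array([1])       # 1 = foreground, 0 = background\n",
    "\n",
    "x, y = demo_point[0]\n",
    "demo_image[y, x] = 255  # NumPy indexes rows (y) first, then columns (x)\n",
    "print(np.argwhere(demo_image[..., 0] == 255))  # (row, col) of the marked pixel"
   ]
  },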
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, scores, logits = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    multimask_output=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(masks.shape[-1]):\n",
    "    mask = masks[..., i]\n",
    "    score = scores[..., i].item()\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.imshow(image)\n",
    "    show_mask(mask[0], plt.gca())\n",
    "    show_points(input_point, input_label, plt.gca())\n",
    "    plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
    "    plt.axis('off')\n",
    "    plt.show()  "
   ]
  },
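  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When only one mask is needed, the highest-scoring one can be selected with an argmax over the scores. The sketch below uses placeholder NumPy arrays rather than the real `masks` and `scores`. The `(1, H, W, 3)` mask layout and `(1, 3)` score layout are assumptions inferred from the indexing in the loop above, and the `demo_*` names and values are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Placeholder outputs standing in for `masks` and `scores` (values are made up)\n",
    "demo_masks = np.zeros((1, 4, 4, 3), dtype=bool)\n",
    "demo_masks[0, 1:3, 1:3, 1] = True\n",
    "demo_scores = np.array([[0.71, 0.93, 0.88]])\n",
    "\n",
    "best_idx = int(np.argmax(demo_scores[0]))  # index of the highest score\n",
    "best_mask = demo_masks[0, ..., best_idx]   # (H, W) mask for that index\n",
    "print(best_idx, best_mask.shape)"
   ]
  },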
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specifying a specific object with additional points"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = mx.array([[500, 375], [1125, 625]])\n",
    "input_label = mx.array([1, 1])\n",
    "mask_input = logits[..., mx.argmax(scores)]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    mask_input=mask_input[..., None],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = mx.array([[500, 375], [1125, 625]])\n",
    "input_label = mx.array([1, 0])\n",
    "mask_input = logits[..., mx.argmax(scores)]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    mask_input=mask_input[..., None],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specifying a specific object with a box"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model can also take a box as input, provided in xyxy format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_box = mx.array([425, 600, 700, 875])"
   ]
  },
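  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some annotation formats (COCO, for example) store boxes as (x, y, width, height) instead. A minimal sketch of converting such a box to the xyxy format used here follows. `xywh_to_xyxy` is a hypothetical helper written for this example; it is not part of the `segment_anything` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def xywh_to_xyxy(box_xywh):\n",
    "    # Convert [x, y, width, height] to [x0, y0, x1, y1]\n",
    "    x, y, w, h = box_xywh\n",
    "    return np.array([x, y, x + w, y + h])\n",
    "\n",
    "# The box used above, written as top-left corner plus width and height\n",
    "print(xywh_to_xyxy([425, 600, 275, 275]))"
   ]
  },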
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    box=input_box[None, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0, ..., 0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combining points and boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Points and boxes may be combined, just by including both types of prompts to the predictor. Here this can be used to select just the trucks's tire, instead of the entire wheel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_box = mx.array([425, 600, 700, 875])\n",
    "input_point = mx.array([[575, 750]])\n",
    "input_label = mx.array([0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    box=input_box,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0, ..., 0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batched prompt inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SamPredictor` can take multiple input prompts for the same image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_boxes = mx.array([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    box=input_boxes,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "for mask in masks:\n",
    "    show_mask(mask, plt.gca(), random_color=True)\n",
    "for box in input_boxes:\n",
    "    show_box(box, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## End-to-end batched inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image1 = image  # truck.jpg from above\n",
    "image1_boxes = mx.array([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "])\n",
    "\n",
    "image2 = cv2.imread('images/groceries.jpg')\n",
    "image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
    "image2_boxes = mx.array([\n",
    "    [450, 170, 520, 350],\n",
    "    [350, 190, 450, 350],\n",
    "    [500, 170, 580, 350],\n",
    "    [580, 170, 640, 350],\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both images and prompts are input as mlx array that are already transformed to the correct frame. Inputs are packaged as a list over images, which each element is a dict that takes the following keys:\n",
    "* `image`: The input image as a mlx array in HWC format.\n",
    "* `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.\n",
    "* `point_coords`: Batched coordinates of point prompts.\n",
    "* `point_labels`: Batched labels of point prompts.\n",
    "* `boxes`: Batched input boxes.\n",
    "* `mask_inputs`: Batched input masks.\n",
    "\n",
    "If a prompt is not present, the key can be excluded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from segment_anything.utils.transforms import ResizeLongestSide\n",
    "resize_transform = ResizeLongestSide(sam.vision_encoder.img_size)\n",
    "\n",
    "def prepare_image(image, transform, device):\n",
    "    image = transform.apply_image(image)\n",
    "    image = mx.array(image)\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_input = [\n",
    "    {\n",
    "        'image': prepare_image(image1, resize_transform, sam),\n",
    "        'boxes': resize_transform.apply_boxes(image1_boxes, image1.shape[:2]),\n",
    "        'original_size': image1.shape[:2]\n",
    "    },\n",
    "    {\n",
    "        'image': prepare_image(image2, resize_transform, sam),\n",
    "        'boxes': resize_transform.apply_boxes(image2_boxes, image2.shape[:2]),\n",
    "        'original_size': image2.shape[:2]\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_output = sam(batched_input, multimask_output=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output is a list over results for each input image, where list elements are dictionaries with the following keys:\n",
    "* `masks`: A batched mlx array of predicted binary masks, the size of the original image.\n",
    "* `iou_predictions`: The model's prediction of the quality for each mask.\n",
    "* `low_res_logits`: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(20, 20))\n",
    "\n",
    "ax[0].imshow(image1)\n",
    "for mask in batched_output[0]['masks']:\n",
    "    show_mask(np.array(mask), ax[0], random_color=True)\n",
    "for box in image1_boxes:\n",
    "    show_box(np.array(box), ax[0])\n",
    "ax[0].axis('off')\n",
    "\n",
    "ax[1].imshow(image2)\n",
    "for mask in batched_output[1]['masks']:\n",
    "    show_mask(np.array(mask), ax[1], random_color=True)\n",
    "for box in image2_boxes:\n",
    "    show_box(np.array(box), ax[1])\n",
    "ax[1].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
